{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████| 338M/338M [00:16<00:00, 20.9MiB/s]\n"
     ]
    }
   ],
   "source": [
    "import pickle\n",
    "import sys\n",
    "sys.path.append(\"../..\")\n",
    "\n",
    "from aips.spark.dataframe import from_csv\n",
    "from pyspark.conf import SparkConf\n",
    "from pyspark.sql import SparkSession\n",
    "import os\n",
    "from itertools import groupby\n",
    "import aips.data_loaders.movies as movies\n",
    "import requests\n",
    "import shutil\n",
    "from tqdm import tqdm\n",
    "import tarfile\n",
    "import clip\n",
    "import torch\n",
    "\n",
    "remote_image_path = \"http://image.tmdb.org/t/p/w780/\"\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "model, preprocess = clip.load(\"ViT-B/32\", device=\"cpu\")\n",
    "\n",
    "conf = SparkConf()\n",
    "conf.set(\"spark.driver.memory\", \"8g\")\n",
    "conf.set(\"spark.executor.memory\", \"8g\")\n",
    "conf.set(\"spark.dynamicAllocation.enabled\", \"true\")\n",
    "conf.set(\"spark.dynamicAllocation.executorMemoryOverhead\", \"8g\")\n",
    "spark = SparkSession.builder.appName(\"AIPS\").config(conf=conf).getOrCreate()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "def dump(data, cache_name):\n",
    "    cache_file_name = f\"../../data/tmdb/{cache_name}.pickle\"\n",
    "    os.makedirs(os.path.dirname(cache_file_name), exist_ok=True)\n",
    "    with open(cache_file_name, \"wb\") as fd:\n",
    "        pickle.dump(data, fd)\n",
    "\n",
    "def dump_dataframe(movies_dataframe, cache_name):\n",
    "    movies = movies_dataframe.rdd.map(lambda row: row.asDict()).collect()\n",
    "    dump(movies, cache_name=cache_name)\n",
    "\n",
    "def read(cache_name):\n",
    "    cache_file_name = f\"../../data/tmdb/{cache_name}.pickle\"\n",
    "    with open(cache_file_name, \"rb\") as fd:\n",
    "        return pickle.load(fd)\n",
    "\n",
    "def compress_file(file_name, root_path=\"../../data/tmdb/\"):    \n",
    "    with tarfile.open(f\"{root_path}{file_name}.tgz\", \"w:gz\") as tar:\n",
    "        tar.add(f\"{root_path}{file_name}.pickle\", arcname=f\"{file_name}.pickle\")\n",
    "\n",
    "def load_image(file_name, load_remote, log=False):\n",
    "    full_local_path = f\"../../data/tmdb/large_movie_images/{file_name}.jpg\"    \n",
    "    try:\n",
    "        exists = os.path.exists(full_local_path)\n",
    "        if log: print(f\"{full_local_path} exists: {exists}\")\n",
    "        if not exists and load_remote:\n",
    "            remote_image_url = f\"{remote_image_path}{file_name}.jpg\"\n",
    "            response = requests.get(remote_image_url, stream=True)\n",
    "            with open(full_local_path, 'wb') as out_file:\n",
    "                shutil.copyfileobj(response.raw, out_file)\n",
    "            del response\n",
    "            if log: print(f\"Downloaded: {full_local_path}\")\n",
    "        image = Image.open(full_local_path)\n",
    "        if log: print(\"File Found\")\n",
    "        return image\n",
    "    except:\n",
    "        if log: print(f\"No Image Available {full_local_path}\")\n",
    "        return []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def download_movie_images(movie_file=\"tmdb_movies\"):\n",
    "    movie_data = read(movie_file)\n",
    "    for movie in tqdm(movie_data, total=len(movie_data)):\n",
    "        if movie[\"movie_image_ids\"]:\n",
    "            for image_id in movie[\"movie_image_ids\"].split(\",\"):                \n",
    "                load_image(image_id, True)\n",
    "                \n",
    "def compute_image_embedding(image_id, log=False):\n",
    "    try:        \n",
    "        image = load_image(image_id, True, log=log)\n",
    "        inputs = preprocess(image).unsqueeze(0).to(\"cpu\")\n",
    "        return model.encode_image(inputs).tolist()[0]\n",
    "    except Exception as e:\n",
    "        if log: print(f\"Image processing exception: {e}\")\n",
    "        return []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_movie_data_with_image_ids(cache_name=\"tmdb_movies\", log=False):\n",
    "    \"Merges movie image ids with base movie data, generating a movie \"\n",
    "    title_movie_map_file = \"../../data/tmdb/movie_data.csv\"\n",
    "    dataframe = from_csv(title_movie_map_file)\n",
    "    movie_image_ids = {}\n",
    "    for k, g in groupby([row.asDict() for row in dataframe.collect()],\n",
    "                        lambda m: m[\"tooltip\"].lower()):\n",
    "        ids = [m[\"path\"].split(\"/\")[-1][:-4] for m in g]\n",
    "        movie_image_ids[k] = ids\n",
    "    \n",
    "    movie_dataframe = movies.load_dataframe(\"../../data/tmdb.json\", movie_image_ids)\n",
    "    dump_dataframe(movie_dataframe, cache_name)\n",
    "    compress_file(cache_name)\n",
    "\n",
    "def generate_image_embeddings_data(cache_name=\"movies_with_image_embeddings\", log=False):\n",
    "    \"Calculates and caches embeddings for all large movie images\"\n",
    "    movie_data = read(\"tmdb_movies\")\n",
    "    movie_ids, titles, image_ids, embeddings = [], [], [], []\n",
    "    for movie in tqdm(movie_data, total=len(movie_data)):\n",
    "        if movie[\"movie_image_ids\"]:\n",
    "            for image_id in movie[\"movie_image_ids\"].split(\",\"):\n",
    "                embedding = compute_image_embedding(image_id, log=log)\n",
    "                print(embedding)\n",
    "                if embedding:\n",
    "                    movie_ids.append(movie[\"id\"])\n",
    "                    titles.append(movie[\"title\"])\n",
    "                    image_ids.append(image_id)\n",
    "                    embeddings.append(embedding)\n",
    "\n",
    "    embeddings_data = {\"movie_ids\": movie_ids, \"titles\": titles, \"image_ids\": image_ids, \"image_embeddings\": embeddings}\n",
    "    dump(embeddings_data, cache_name)\n",
    "    compress_file(cache_name)\n",
    "    return embeddings_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "#embeddings = generate_image_embeddings_data()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
